# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import uuid

import config
import numpy as np
import utils
from numpy.random import randint, randn

import paddle
from paddle.incubate.autograd import primops

# Switch Paddle to static-graph mode at import time; the test case below
# builds every op inside a paddle.static.Program.
paddle.enable_static()


@utils.place(config.DEVICES)
@utils.parameterize(
    (
        utils.TEST_CASE_NAME,
        'op',
        'args',
        'kwargs',
        'expected_shape',
        'expected_dtype',
    ),
    (
        ('add', primops.add, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('mul', primops.mul, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('div', primops.div, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('sqrt', primops.sqrt, randn(2, 3), {}, (2, 3), 'float64'),
        ('tanh', primops.tanh, randn(2, 3), {}, (2, 3), 'float64'),
        ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
        ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
        ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
        ('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
        ('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'),
        ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
        (
            'cast',
            primops.cast,
            randn(2, 3),
            {'dtype': paddle.int64},
            (2, 3),
            'int64',
        ),
        (
            'reshape',
            primops.reshape,
            randn(2, 3),
            {'shape': (3, 2)},
            (3, 2),
            'float64',
        ),
        (
            'broadcast',
            primops.broadcast,
            randn(2),
            {'shape': (3, 2)},
            (3, 2),
            'float64',
        ),
        (
            'transpose',
            primops.transpose,
            randn(2, 3),
            {'axis': (1, 0)},
            (3, 2),
            'float64',
        ),
        (
            'concat_axis0',
            primops.concat,
            ((randn(2, 3), randn(2, 3)),),
            {'axis': 0},
            (4, 3),
            'float64',
        ),
        (
            'concat_axis1',
            primops.concat,
            ((randn(2, 3), randn(2, 3)),),
            {'axis': 1},
            (2, 6),
            'float64',
        ),
        (
            'reduce_axis1',
            primops.reduce_sum,
            randn(2, 3),
            {'axis': (1,)},
            (2,),
            'float64',
        ),
        (
            'reduce_axis01',
            primops.reduce_sum,
            randn(2, 3),
            {'axis': (0, 1)},
            (1,),
            'float64',
        ),
        (
            'split',
            primops.split,
            randn(2, 3),
            {'num_or_sections': [1, 2], 'axis': 1},
            ((2, 1), (2, 2)),
            ('float64', 'float64'),
        ),
        (
            'matmul',
            primops.matmul,
            (randn(2, 3), randn(3, 2)),
            {},
            (2, 2),
            'float64',
        ),
        (
            'slice_select',
            primops.slice_select,
            randn(3, 2),
            {'axis': [0], 'starts': [0], 'ends': [2], 'strides': [1]},
            (2, 2),
            'float64',
        ),
        (
            'slice_assign',
            primops.slice_assign,
            (randn(2, 3), randn(2, 2)),
            {'axis': [1], 'starts': [1], 'ends': [3], 'strides': [1]},
            (2, 3),
            'float64',
        ),
        (
            'gather',
            primops.gather,
            (randn(3, 2), randint(0, 2, (5,), np.int32)),
            {'axis': 0},
            (5, 2),
            'float64',
        ),
        (
            'scatter_add',
            primops.scatter_add,
            (randn(3, 2), randn(5, 2), randint(0, 2, (5,), np.int32)),
            {'axis': 0},
            (3, 2),
            'float64',
        ),
        (
            'fill_const',
            primops.fill_const,
            (),
            {'value': 10, 'shape': (3, 2), 'dtype': paddle.float32},
            (3, 2),
            'float32',
        ),
        ('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'),
        (
            'select',
            primops.select,
            (randn(2, 3) > 0, randn(2, 3), randn(2, 3)),
            {},
            (2, 3),
            'float64',
        ),
        ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
        ('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
        ('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
        ('ge', primops.ge, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
        ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
    ),
)
class TestPrimops(unittest.TestCase):
    """Check the output shape and dtype of each primitive op in static mode.

    The test reads ``self.op``, ``self.args``, ``self.kwargs``,
    ``self.expected_shape`` and ``self.expected_dtype``; these are presumably
    injected per case by ``utils.parameterize`` from the table above
    (NOTE(review): confirm against ``utils``). Ops are only built inside a
    static Program, never executed, so only output metadata is verified.
    """

    @classmethod
    def setUpClass(cls):
        """Enable static-graph mode for the tests in this class."""
        paddle.enable_static()

    @classmethod
    def tearDownClass(cls):
        """Leave static-graph mode once the tests in this class are done."""
        paddle.disable_static()

    def test_prim_ops(self):
        """Build the op in a fresh Program and compare output shape/dtype."""
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            args = self._as_tuple(self.args)
            args = self.arr2var(args)
            results = self.op(*args, **self.kwargs)
            results = self._as_tuple(results)
            expected_shape = self._as_tuple(self.expected_shape)
            expected_dtype = self._as_tuple(self.expected_dtype)

            # zip() stops at the shortest sequence, so an op returning the
            # wrong number of outputs would otherwise go unnoticed.
            self.assertEqual(len(results), len(expected_shape))
            self.assertEqual(len(results), len(expected_dtype))

            for r, shape, dtype in zip(results, expected_shape, expected_dtype):
                self.assertEqual(r.shape, shape)
                # Compare only the part after the first '.' of the dtype's
                # string form (e.g. the 'float64' suffix).
                self.assertEqual(str(r.dtype).split('.')[1], dtype)

    def arr2var(self, arr):
        """Convert numpy ndarrays to paddle static Variables recursively.

        Args:
            arr: A sequence whose items are ndarrays or nested sequences
                of ndarrays.

        Returns:
            A list mirroring ``arr`` in which every ndarray is replaced by a
            ``paddle.static.data`` placeholder with the same shape and dtype.
            Each placeholder gets a uuid-based name. Nested sequences become
            nested lists.
        """
        return [
            paddle.static.data(f'x{uuid.uuid4()}', v.shape, v.dtype)
            if isinstance(v, np.ndarray)
            else self.arr2var(v)
            for v in arr
        ]

    def _as_tuple(self, input):
        """Normalize ``input`` into a sequence of items.

        * An empty tuple/list is returned unchanged.
        * A non-sequence value, or a tuple/list made up only of ints (treated
          as one shape), is wrapped in a 1-tuple.
        * Any other tuple/list is returned unchanged.
        """
        if isinstance(input, (tuple, list)) and len(input) == 0:
            return input
        if not isinstance(input, (tuple, list)) or all(
            isinstance(i, int) for i in input
        ):
            return (input,)
        return input

# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
